library(MASS)
library(randomForest)
library(KRLS)
library(grf)
library(AER)
library(glmnet)
library(CBPS)
library(sandwich)
library(arm)
library(sensemakr)
library(PLCE)
library(tictoc)
library(ranger)

# ---- Session setup ----
# Clear every object from the global workspace before the run starts.
# NOTE(review): kept for backward compatibility, but this also wipes anything
# the caller had defined in an interactive session.
rm(list = ls())

# Initial random identifier. It is overwritten once per design cell inside the
# simulation loop; the value left at the end is used in the output file name.
ran.num <- round(runif(1) * 1e8)

#source('~/Dropbox/InfluenceFunctions/Code/HOE.R')
#source('../Code/HOE_APSR.R')
# devtools::load_all('~/Dropbox/Github/PLCE')

# Start the timer for the full script; the matching toc() is after the loop.
tic("Whole thing")

# Accumulators: one row is appended per design cell.
theta.run <- se.run <- cover.run <- NULL

# Defaults for the design settings. Each of these is reassigned inside the
# simulation loop; the values here only matter when running loop-body lines
# by hand.
n <- 1000
k <- 5

# Use the reserved words TRUE/FALSE rather than T/F, which are ordinary
# variables and can be reassigned.
meanhet <- errhet <- inter <- TRUE
i.REs <- TRUE
treathet <- TRUE
 # Simulation grid. For every sample size, covariate count and
 # (treathet, inter, i.REs) combination, one data set is drawn, six estimators
 # (DML, HOE/plce, causal forest, CBPS, OLS, KRLS) are fit, and one row is
 # appended to each of theta.run, se.run and cover.run.
 # NOTE(review): T/F are used as shorthand for TRUE/FALSE throughout.
 for(i.n in c(250,500,750, 1000,2000)){
#for(i.n in c(1000)){
  for(i.k in c(5) ){
    # for(meanhet in c(T,F)){
      # for(errhet in c(T,F)){
      for(treathet in c(T,F)){
        for(inter in c(T,F)){
          for(i.REs in c(T,F)){
            # meanhet and errhet are not looped over separately (those loops
            # are commented out above); both simply follow treathet.
            meanhet <- errhet <- treathet
          n<-i.n
          k<-i.k
          
          # Draw a fresh seed for this cell, set it, and print it together
          # with the design flags.
          ran.num<-round(runif(1)*1e8)
          set.seed(ran.num)
          print(ran.num)
          print(c(meanhet,errhet,inter,i.REs))

          
          ##Format data ----
          # Covariance matrix: 1 on the diagonal, 0.5 everywhere else.
          var.mat <- diag(k)
          var.mat[var.mat==0] <- 0.5
          # covariates
          # `Sig` relies on partial matching of mvrnorm's `Sigma` argument.
          X <- MASS::mvrnorm(n, rep(0, k), Sig = var.mat)
          # Standardize each column, then divide each column by its
          # root-mean-square. Column 1 is rescaled on its own first and then
          # again with all the other columns on the following line.
          X<-apply(X,2,scale)
          X[,1]<-X[,1]/(mean(X[,1]^2)^.5)
          X<-apply(X,2,FUN=function(x) x/(mean(x^2))^.5)
          
          
          ## REs ----
          # ids.map is assigned three times; only the last assignment is used.
          # The first one still calls sample() and therefore advances the
          # random number stream, so removing it would change the draws below.
          ids.map<-sample(as.factor(letters[1:20]),n,T)
          ids.map <- as.factor(ceiling(sort(X[,1],ind=T)$ix/20))
          # Final grouping: a random permutation of 1..n (the sort index of
          # n normal draws) divided into consecutive blocks of 20, i.e. groups
          # of at most 20 units that are unrelated to X.
          ids.map <- as.factor(ceiling(sort(rnorm(n),ind=T)$ix/20))
          
          # One standard-normal effect per group, looked up for every unit.
          # Multiplying by i.REs zeroes the effects when i.REs is FALSE.
          res.map <- rnorm(length(unique(ids.map)))
          names(res.map)<-sort(unique(ids.map))
          res.true <-  (res.map[ids.map])*i.REs
          # Group indicator matrix: one column per level, no intercept.
          X.REs<-model.matrix(~ids.map-1)
          
          # Standard deviation of the treatment noise. With errhet it is
          # sqrt((1 + X1 + b)^2 / 2), where b is 1 only if X[,1] has exactly
          # two distinct values; otherwise the sd is 1.
          if(errhet)	sd.use<-((1+X[,1]+(length(unique(X[,1]))==2) )^2/2)^.5 else sd.use<-1
          errst<-rnorm(n,sd=sd.use)#rt(n,8)*sd.use#
          # Treatment = first covariate + noise + group effect.
          treat<-(X[,1])+errst+res.true
          errsy<-rnorm(n,sd=1)#rt(n,8)*sd.use#
          # Outcome: with meanhet the treatment is multiplied by X1^2;
          # otherwise treatment and X1^2 enter additively.
          if(meanhet) Y<-1+treat*(X[,1]^2)+errsy else Y<-1+treat+(X[,1]^2)+errsy
          # The same group effect also enters the outcome.
          Y <- Y + res.true
          ##Shift Y and treat by kernel-weighted averages over the other units
          if(inter){
            # Gaussian kernel weights on pairwise differences in X[,1];
            # the diagonal is zeroed so a unit never weights itself.
            m1<-exp(-5*(outer(X[,1],X[,1],"-")^2))
            diag(m1)<-0
            # These two assignments are overwritten immediately below.
            oY<-m1%*%(X[,1]^2)/rowSums(m1)
            oT<-m1%*%(X[,1])/rowSums(m1)
            
              # Values actually used: the weighted average of the other
              # units' treatment (oY) and of their X[,1] (oT).
              oY<-m1%*%(treat)/rowSums(m1)
              oT<-m1%*%(X[,1])/rowSums(m1)
            
            
            Y<-Y+oY
            treat<-treat-oT
          }
          # The result of this fit is discarded (nothing is assigned, and
          # values are not auto-printed inside a loop).
          lm(Y~treat+X)
          
          # Replace the first two covariates by the linear combinations
          # X1 - .5*X2 and X2 - .5*X1; the estimators below see these columns
          # rather than the X[,1] used to generate treat and Y.
          X2.model<-X
          X2.model[,1]<-X[,1]-.5*X[,2]
          X2.model[,2]<-X[,2]-.5*X[,1]
          ## Rotates X's
          X<-cbind(X2.model)
          
          # Covariates plus group indicators with constant columns dropped.
          # X.all is not referenced again in this loop.
          X.all <- cbind(X,model.matrix(~ids.map-1))
          X.all <- X.all[,apply(X.all,2,sd)>0]

          ##Fit random forest, DML, HOE, OLS, CBPS, LCE, KRLS ----
          # dml.points and dml.ses are initialized but not used below.
          dml.points<-dml.ses<-NULL
          dml1<-NULL
          # Wrapper so that sapply can call PLCE:::DML repeatedly; the index
          # argument z is ignored.
          est.dml<-function(z,Y0=Y,treat0=treat,X0=data.frame(X,X.REs)) PLCE:::DML(Y0,treat0,X0)
          # Ten repeated DML fits.
          # NOTE(review): presumably row 1 of dml.all holds the point
          # estimates and row 2 the standard errors; confirm against
          # PLCE:::DML.
          dml.all<-sapply(1:10,FUN=est.dml)
          # Point estimate: median of row 1 across the ten fits.
          dml1$point<-median(unlist(dml.all[1,]))
          # SE: sqrt(variance of row 1 across fits + mean of squared row 2).
          dml1$se<-(var(unlist(dml.all[1,]))+mean(unlist(dml.all[2,])^2))^.5
          
          # CBPS design matrix: covariates, their squares and the group
          # indicators, with constant columns removed.
          X.cb<-cbind(X,X^2,X.REs)
          X.cb<-X.cb[,apply(X.cb,2,sd)>0]
          # NOTE(review): check.cor()$k is presumably the set of columns to
          # keep after a correlation screen; confirm in PLCE.
          X.cb<-X.cb[,PLCE:::check.cor(X.cb,thresh=1e-3)$k]
          cb1<-CBPS(treat~X.cb,method="exact")
          # Drop columns whose coefficient is NA in the weighted regression,
          # then refit the weighted regression on the remaining columns.
          X.cb<-X.cb[,!is.na(lm(Y~treat+X.cb,w=cb1$w)$coef[-c(1:2)])]
          lm.cb<-lm(Y~treat+X.cb,w=cb1$w)
          # Variance matrix for the outcome-model coefficients; element [2,2]
          # (the treat coefficient) is used for se.cbps below.
          v1<-vcov_outcome(object=cb1,Y=Y,Z=cbind(1,treat,X.cb),lm.cb$coef)
          
          # Disabled: comparison of three causal-forest specifications
          # (kept for reference).
          #c0<-causal_forest(X=cbind(X,X.REs),Y=Y,W=treat, num.trees = 4000)
          # ape0<-average_treatment_effect(c0);ape0
          # z0<-abs(ape0[1]-1)/ape0[2]
          # c1<-causal_forest(X=cbind(X,X.REs),Y=Y,W=treat, num.trees = 4000,cluster=ids.map)
          # ape1<-average_treatment_effect(c1);ape1
          # z1<-abs(ape1[1]-1)/ape1[2]
          # c2<-causal_forest(X=cbind(X),Y=Y,W=treat, num.trees = 4000,cluster=ids.map)
          # ape2<-average_treatment_effect(c2);ape2;(ape2[1]-1)/ape2[2]
          # z2<-abs(ape2[1]-1)/ape2[2]
          # 
          # best.grf <- which.min(c(z0,z1,z2))
          # if(best.grf==1) c1.rf <- c0
          # if(best.grf==2) c1.rf <- c1
          # if(best.grf==3) c1.rf <- c2
          # 
          
          # Causal forest on covariates plus group indicators, 4000 trees.
          # ape[1] is used below as the estimate and ape[2] as its standard
          # error. The two trailing expressions are evaluated and discarded.
          c1.rf<-causal_forest(X=cbind(X,X.REs),Y=Y,W=treat, num.trees = 4000)
          ape<-average_treatment_effect(c1.rf);ape;(ape[1]-1)/ape[2]
          
          ##Fit HOE---
          # Time the plce fit with a nested tic/toc pair; toc() prints the
          # elapsed time for the "overall" timer.
          t1<-tic("overall")
          h1<-plce(Y,treat,X,id=ids.map,num.fit=50, var.type="HC2");h1$point;h1$se
          t2<-toc()
          
          
          ##Fit OLS ----
          # Regression of Y on treat and the covariates only (no group
          # indicators).
          lm.ols<-lm(Y~treat+X)
          
          ##Fit KRLS ----
          #b1<-krls(y=Y,X=cbind(treat,X,X.REs),which.derivative=c(1,2))
          # treat is the first column of the KRLS input matrix; the first
          # element of the summary's average derivatives is used below as the
          # KRLS estimate.
          k1<-KRLS2::krls(y=Y,X=cbind(treat,X,X.REs), epsilon=0.001)
          b1 <- KRLS2:::summary.krls2(k1)
          # Evaluated and discarded.
          b1$sumavgderiv[1,1:2]
          
          ##Collect point estimates and standard errors ----
          
          theta.dml<-dml1$point
          theta.hoe<-h1$point
          theta.rf<-ape[1]
          theta.cbps<-lm.cb$coef[2]
          theta.ols<-lm.ols$coef[2]
          theta.krls<-b1$avgderivatives[1]
          
          se.dml<-dml1$se
          se.hoe<-h1$se
          # se.jk is extracted but not stored in any of the result rows.
          se.jk<-h1$jk
          se.rf<-ape[2]#(var(c1$pred)/n+mean(c1$var/n))^.5
          se.cbps<- v1[2,2]^.5
          # Heteroskedasticity-robust (HC3) standard error for the OLS fit.
          se.ols<- vcovHC(lm.ols,"HC3")[2,2]^.5
          # NOTE(review): `$var.a` relies on partial name matching,
          # presumably resolving to `var.avgderivatives`; confirm.
          se.krls<-(b1$var.a[1])^.5
          
          
          
          # Result rows: the six design columns followed by one entry per
          # estimator, in the order DML, HOE, RF, CBPS, OLS, KRLS.
          theta.curr<-c(i.n, i.k, meanhet,errhet,inter,i.REs,
                        theta.dml,theta.hoe,theta.rf,theta.cbps,theta.ols, theta.krls
          )
          
          se.curr<-c(i.n, i.k, meanhet,errhet,inter,i.REs,
                     se.dml,se.hoe,se.rf,se.cbps,se.ols, se.krls
          )
          
          # Coverage indicator: whether the estimate lies within 1.64 standard
          # errors of the value 1. 1.64 is roughly the 0.95 standard normal
          # quantile, so this corresponds to a nominal 90% interval.
          cover.curr<-c(i.n, i.k, meanhet,errhet,inter,i.REs,
                        abs(theta.dml-1)<1.64*se.dml,
                        abs(theta.hoe-1)<1.64*se.hoe,
                        abs(theta.rf-1)<1.64*se.rf,
                        abs(theta.cbps-1)<1.64*se.cbps,
                        abs(theta.ols-1)<1.64*se.ols,
                        abs(theta.krls-1)<1.64*se.krls
            )
          
          # Append this cell's rows to the running result matrices.
          theta.run<-rbind(theta.run,theta.curr)
          se.run<-rbind(se.run,se.curr)
          cover.run<-rbind(cover.run,cover.curr)
          
          # Progress banner marking the end of this design cell.
          for(i.print in 1:5) cat("#######################################\n")
          cat(c("Done with",c(meanhet,errhet,inter,i.REs,n,k),"\n"))
          for(i.print in 1:5) cat("#######################################\n")
          
          
        }}}}}#}

# Stop the "Whole thing" timer started before the simulation loop.
toc()

# Give all three result matrices the same column labels: the six design
# columns first, then one column per estimator.
colnames(cover.run) <- c(
  "n", "k", "meanhet", "errhet", "inter", "REs",
  "DML", "HOE", "RF", "CBPS", "OLS", "KRLS"
)
colnames(se.run) <- colnames(cover.run)
colnames(theta.run) <- colnames(se.run)

# Bundle the results into a single list.
output <- list(theta = theta.run, se = se.run, cover = cover.run)

# Save to a file in the working directory named "output_<ran.num>", where
# ran.num is the last value assigned to it.
name.save <- paste0("output_", ran.num)
save(output, file = name.save)

